"""股票筛选与评估视图。"""
from datetime import date, datetime, timedelta
from typing import Dict, List, Optional, Tuple
import json
import sqlite3

import numpy as np
import pandas as pd
import streamlit as st

from app.features.evaluation import evaluate_factor
from app.features.factors import DEFAULT_FACTORS
from app.features.validation import check_data_sufficiency
from app.utils.config import get_config
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger

# Module-level logger for this view.
LOGGER = get_logger(__name__)
# Passed as ``extra=`` on log calls to tag records with this module's stage.
LOG_EXTRA = {"stage": "stock_eval"}


def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
    """Add any missing optional columns to ``investment_pool`` (best effort).

    Existing columns are read via ``PRAGMA table_info``; each absent optional
    column is added with ``ALTER TABLE``. Every failure is ignored so schema
    drift never blocks the caller. For ``created_at`` a variant with an
    expression default is attempted first, then a plain ``TEXT`` column
    (SQLite refuses non-constant defaults in ``ADD COLUMN``).
    """
    try:
        table_info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
    except sqlite3.Error:
        return

    # Works with both the default tuple rows and sqlite3.Row rows.
    existing = set()
    for entry in table_info:
        if entry is None:
            continue
        existing.add(entry["name"] if isinstance(entry, sqlite3.Row) else entry[1])

    # Candidate DDL statements per column, tried in order until one succeeds.
    optional_columns = (
        ("name", ("ALTER TABLE investment_pool ADD COLUMN name TEXT",)),
        ("industry", ("ALTER TABLE investment_pool ADD COLUMN industry TEXT",)),
        (
            "created_at",
            (
                "ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))",
                "ALTER TABLE investment_pool ADD COLUMN created_at TEXT",
            ),
        ),
    )
    for column, statements in optional_columns:
        if column in existing:
            continue
        for ddl in statements:
            try:
                conn.execute(ddl)
            except sqlite3.Error:
                continue
            break


def _get_latest_trading_date() -> date:
    """Return the latest trade date recorded in ``daily_basic`` up to today.

    ``trade_date`` values are compared against today's local date formatted as
    ``YYYYMMDD`` and parsed with that same format. Falls back to yesterday
    when the table has no matching rows, the query raises ``sqlite3.Error``,
    or the stored value cannot be parsed.
    """
    today = datetime.now()
    fallback = today.date() - timedelta(days=1)
    try:
        with db_session(read_only=True) as conn:
            # MAX() yields the same row as GROUP BY/ORDER BY DESC/LIMIT 1,
            # and returns (None,) when nothing matches.
            row = conn.execute(
                "SELECT MAX(trade_date) FROM daily_basic WHERE trade_date <= ?",
                (today.strftime("%Y%m%d"),),
            ).fetchone()
    except sqlite3.Error as exc:
        LOGGER.warning("查询最新交易日失败，回退为前一日 err=%s", exc, extra=LOG_EXTRA)
        return fallback

    if not row or not row[0]:
        return fallback
    try:
        return datetime.strptime(str(row[0]), "%Y%m%d").date()
    except ValueError:
        LOGGER.warning("最新交易日格式无法解析 value=%s", row[0], extra=LOG_EXTRA)
        return fallback


def _get_main_board_universe(trade_date: str) -> List[str]:
    """Return listed 沪深主板 stock codes that have daily bars up to ``trade_date``.

    Selects ``stock_basic`` rows with ``market = '主板'`` whose ``list_status``
    is ``'L'`` (NULL treated as ``'L'``) and that have at least one ``daily``
    row with ``trade_date <= trade_date``. Codes are de-duplicated and sorted.

    Returns an empty list (and logs a warning) if the query raises
    ``sqlite3.Error``.
    """
    # EXISTS lets SQLite stop at the first qualifying daily bar per stock
    # instead of joining every historical bar and de-duplicating afterwards.
    query = """
        SELECT DISTINCT s.ts_code
        FROM stock_basic AS s
        WHERE s.market = '主板'
          AND IFNULL(s.list_status, 'L') = 'L'
          AND EXISTS (
              SELECT 1
              FROM daily AS d
              WHERE d.ts_code = s.ts_code
                AND d.trade_date <= ?
          )
        ORDER BY s.ts_code
    """
    try:
        with db_session(read_only=True) as conn:
            rows = conn.execute(query, (trade_date,)).fetchall()
    except sqlite3.Error as exc:
        LOGGER.warning("获取沪深主板股票列表失败 err=%s", exc, extra=LOG_EXTRA)
        return []

    codes: List[str] = []
    for row in rows:
        # Support both sqlite3.Row and plain tuple row factories.
        code = row["ts_code"] if isinstance(row, sqlite3.Row) else row[0]
        if code:
            codes.append(code)
    return codes


def _normalize_universe(universe: Optional[List[str]]) -> List[str]:
    """Return *universe* stripped, upper-cased and de-duplicated.

    Blank and ``None`` entries are dropped and first-seen order is kept.
    A ``None`` or empty input yields an empty list.
    """
    if not universe:
        return []
    cleaned = ((raw or "").strip().upper() for raw in universe)
    # dict.fromkeys preserves insertion order, giving an order-stable dedupe.
    return list(dict.fromkeys(code for code in cleaned if code))

def _format_metric(value: object, *, percent: bool = False) -> str:
    """Format one evaluation metric for the results table.

    ``None``, non-numeric and non-finite values render as ``"N/A"`` so a single
    missing statistic cannot break the table. With ``percent=True`` the value
    is treated as a fraction and shown as a percentage with one decimal;
    otherwise it is shown with four decimals.
    """
    if value is None:
        return "N/A"
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "N/A"
    if not np.isfinite(number):
        return "N/A"
    return f"{number * 100:.1f}%" if percent else f"{number:.4f}"


def render_stock_evaluation() -> None:
    """Render the stock screening and evaluation page.

    The user picks an evaluation window, factors, valuation ranges and a stock
    pool. Factor evaluation runs when the user clicks the button. Stocks are
    then scored with IC-derived weights, and the Top 20 can be written into
    the investment pool.
    """
    log_extra = {"stage": "stock_evaluation_ui"}

    st.subheader("股票筛选与评估")
    LOGGER.info("股票筛选与评估页面已加载", extra=log_extra)

    # 1. Evaluation window.
    col1, col2 = st.columns(2)
    with col1:
        latest_date = _get_latest_trading_date()
        end_date = st.date_input(
            "评估截止日期",
            value=latest_date,
            help="选择评估的截止日期"
        )
    with col2:
        lookback_days = st.slider(
            "回溯天数",
            min_value=30,
            max_value=360,
            value=180,
            step=30,
            help="选择评估的历史数据长度"
        )
    start_date = end_date - timedelta(days=lookback_days)

    # 2. Factor selection, grouped by factor-name prefix.
    st.markdown("##### 评估因子选择")
    group_prefixes = {
        "动量类因子": "mom_",
        "波动率类因子": "volat_",
        "换手率类因子": "turn_",
        "估值类因子": "val_",
        "量价类因子": "volume_",
        "市场类因子": "market_",
    }
    factor_groups = {
        group: [f for f in DEFAULT_FACTORS if f.name.startswith(prefix)]
        for group, prefix in group_prefixes.items()
    }

    # Factors checked by default. NOTE(review): "risk_penalty" matches none of
    # the prefixes above, so it is never rendered and cannot be selected.
    default_selected = {
        "mom_5",   # 5-day momentum
        "mom_20",  # 20-day momentum
        "mom_60",  # 60-day momentum
        "volat_20",  # 20-day volatility
        "turn_5",   # 5-day turnover
        "turn_20",  # 20-day turnover
        "val_pe_score",  # PE score
        "val_pb_score",  # PB score
        "volume_ratio_score",  # volume-ratio score
        "risk_penalty"  # risk penalty
    }

    selected_factors: List[str] = []
    for group_name, factors in factor_groups.items():
        if not factors:
            continue
        st.markdown(f"###### {group_name}")
        cols = st.columns(3)
        for i, factor in enumerate(factors):
            if cols[i % 3].checkbox(
                factor.name,
                value=factor.name in default_selected,
                help=getattr(factor, "description", None)
            ):
                selected_factors.append(factor.name)

    if not selected_factors:
        st.warning("请至少选择一个评估因子")
        return

    # 3. Valuation filters; a range only becomes a filter once moved off its default.
    default_pe_range: Tuple[float, float] = (-100.0, 200.0)
    default_pb_range: Tuple[float, float] = (-10.0, 20.0)
    default_ps_range: Tuple[float, float] = (0.0, 50.0)
    default_mv_range: Tuple[float, float] = (0.0, 20000.0)

    st.markdown("##### 基本面筛选")
    st.caption("调整估值指标范围以筛选股票池，默认值保留全量范围。")

    def _range_changed(selected: Tuple[float, float], baseline: Tuple[float, float]) -> bool:
        return any(abs(sel - base) > 1e-6 for sel, base in zip(selected, baseline))

    col_pe, col_pb = st.columns(2)
    with col_pe:
        pe_range = st.slider(
            "市盈率 (PE, TTM) 区间",
            min_value=-100.0,
            max_value=200.0,
            value=default_pe_range,
            step=1.0,
            help="当市盈率不在该区间时予以剔除，负值表示纳入亏损公司。"
        )
    with col_pb:
        pb_range = st.slider(
            "市净率 (PB) 区间",
            min_value=-10.0,
            max_value=20.0,
            value=default_pb_range,
            step=0.1,
            help="排除市净率超出该区间的股票。"
        )
    col_ps, col_mv = st.columns(2)
    with col_ps:
        ps_range = st.slider(
            "市销率 (PS, TTM) 区间",
            min_value=0.0,
            max_value=50.0,
            value=default_ps_range,
            step=0.5,
            help="排除市销率超出该区间的股票。"
        )
    with col_mv:
        mv_range = st.slider(
            "总市值 (亿元) 区间",
            min_value=0.0,
            max_value=20000.0,
            value=default_mv_range,
            step=50.0,
            help="根据总市值范围筛选股票，单位为亿元。"
        )

    fundamental_filters: Dict[str, Tuple[float, float]] = {}
    for key, chosen, baseline in (
        ("pe_ttm", pe_range, default_pe_range),
        ("pb", pb_range, default_pb_range),
        ("ps_ttm", ps_range, default_ps_range),
        ("total_mv_billion", mv_range, default_mv_range),
    ):
        if _range_changed(chosen, baseline):
            fundamental_filters[key] = chosen

    # 4. Stock pool.
    st.markdown("##### 股票池范围")
    pool_options = ["沪深主板", "沪深300", "中证500", "中证1000", "全部A股", "自定义"]
    pool_type = st.radio(
        "选择股票池",
        pool_options,
        index=0,  # default: 沪深主板
        horizontal=True
    )

    universe: Optional[List[str]] = None
    trade_date_str = end_date.strftime("%Y%m%d")
    index_code_map = {
        "沪深300": "000300.SH",
        "中证500": "000905.SH",
        "中证1000": "000852.SH"
    }
    if pool_type == "自定义":
        custom_codes = st.text_area(
            "输入股票代码列表(每行一个)",
            help="请输入股票代码，每行一个，例如: 000001.SZ"
        )
        if custom_codes:
            universe = [code.strip() for code in custom_codes.splitlines() if code.strip()]
    elif pool_type == "沪深主板":
        universe = _get_main_board_universe(trade_date_str)
    elif pool_type in index_code_map:
        universe = DataBroker().get_index_stocks(index_code_map[pool_type], trade_date_str)
    # "全部A股" leaves universe as None, meaning "no restriction" downstream.
    universe = _normalize_universe(universe) or None
    if universe is None and pool_type != "全部A股":
        # An empty list here means missing input or missing data; scoring the
        # whole market instead would silently ignore the user's pool choice.
        st.warning(f"股票池「{pool_type}」未获取到任何股票代码，请检查输入或相关数据后重试。")
        return

    # 5. Factor evaluation.
    for key, initial in (
        ("evaluation_results", None),
        ("evaluation_status", "idle"),  # idle / running / completed / error
        ("current_factor", ""),
        ("progress", 0),
    ):
        if key not in st.session_state:
            st.session_state[key] = initial

    # Evaluation runs synchronously inside one script run, so a "running"
    # status seen at the start of a run can only be left over from a run that
    # was interrupted (e.g. a widget-triggered rerun). Reset it so the button
    # is not blocked for the rest of the session.
    if st.session_state.evaluation_status == "running":
        st.session_state.evaluation_status = "idle"

    def run_evaluation_sync(status_box, progress_bar) -> None:
        """Evaluate each selected factor and store the outcome in session state.

        Progress is shown live through the given placeholders, which are
        cleared afterwards. Failures are logged and recorded in
        ``st.session_state.evaluation_error``.
        """
        LOGGER.info(
            "开始因子评估 因子数量=%s 评估日期=%s 至 %s",
            len(selected_factors),
            start_date,
            end_date,
            extra=log_extra
        )
        st.session_state.evaluation_status = "running"
        st.session_state.pop("evaluation_error", None)
        total = len(selected_factors)
        results = []
        try:
            for i, factor_name in enumerate(selected_factors):
                st.session_state.current_factor = factor_name
                st.session_state.progress = i / total * 100
                status_box.info(f"正在评估因子: {factor_name}")
                progress_bar.progress(i / total)

                performance = evaluate_factor(
                    factor_name,
                    start_date,
                    end_date,
                    universe=universe
                )
                results.append({
                    "因子": factor_name,
                    "IC均值": performance.ic_mean,
                    "RankIC均值": performance.rank_ic_mean,
                    "IC信息比率": performance.ic_ir,
                    "夏普比率": performance.sharpe_ratio,
                    "换手率": performance.turnover_rate,
                    "有效样本数": performance.sample_size,
                })
        except Exception as exc:  # UI boundary: report instead of crashing the page
            LOGGER.exception(
                "因子评估失败 factor=%s",
                st.session_state.current_factor,
                extra=log_extra
            )
            st.session_state.evaluation_status = "error"
            st.session_state.evaluation_error = str(exc)
        else:
            st.session_state.evaluation_results = results
            st.session_state.evaluation_status = "completed"
            st.session_state.progress = 100
        status_box.empty()
        progress_bar.empty()

    if st.button("开始评估", disabled=not selected_factors):
        st.session_state.evaluation_results = None
        st.session_state.progress = 0
        # New IC values invalidate any previously computed stock scores.
        st.session_state.pop("stock_scores_cache", None)
        status_box = st.empty()
        progress_bar = st.progress(0.0)
        run_evaluation_sync(status_box, progress_bar)

    # Rendered after the button so the outcome of this run is shown immediately.
    status = st.session_state.evaluation_status
    if status == "completed":
        st.success("因子评估完成！")
    elif status == "error":
        st.error(f"评估失败: {st.session_state.get('evaluation_error', '')}")

    results = st.session_state.evaluation_results
    if not results:
        return

    # Results table.
    st.markdown("##### 因子评估结果")
    result_df = pd.DataFrame(results)
    display_df = result_df.copy()
    for col in ["IC均值", "RankIC均值", "IC信息比率", "夏普比率"]:
        if col in display_df:
            display_df[col] = display_df[col].map(_format_metric)
    if "换手率" in display_df:
        display_df["换手率"] = display_df["换手率"].map(lambda v: _format_metric(v, percent=True))
    if "有效样本数" in display_df:
        display_df["有效样本数"] = (
            pd.to_numeric(display_df["有效样本数"], errors="coerce").fillna(0).astype(int)
        )
    st.dataframe(
        display_df,
        hide_index=True,
        width="stretch"
    )

    # IC bar chart; factors without any valid samples are blanked out.
    factor_names = result_df["因子"].tolist()
    ic_series = pd.to_numeric(result_df["IC均值"], errors="coerce").astype(float)
    if "有效样本数" in result_df:
        sample_series = pd.to_numeric(result_df["有效样本数"], errors="coerce").fillna(0)
        ic_series = ic_series.where(sample_series > 0)
    chart_df = pd.DataFrame({
        "因子": factor_names,
        "IC均值": ic_series.tolist()
    })
    st.bar_chart(chart_df.set_index("因子"))

    ic_array = ic_series.to_numpy(dtype=float)
    finite_mask = np.isfinite(ic_array)
    if not finite_mask.any():
        st.info("所有因子 IC 均值均不可用，请先补充因子数据再评估。")
        return

    usable_factors = [name for name, ok in zip(factor_names, finite_mask) if ok]
    usable_ic = ic_array[finite_mask]
    dropped_factors = [name for name, ok in zip(factor_names, finite_mask) if not ok]
    if dropped_factors:
        st.caption(f"已忽略缺少有效 IC 数据的因子：{', '.join(dropped_factors)}")

    # Signed IC weights normalised so their absolute values sum to 1.
    if np.all(np.abs(usable_ic) <= 1e-9):
        factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
        LOGGER.info("有效因子IC均值均为零，使用均匀权重", extra=log_extra)
    else:
        # At least one |IC| exceeds 1e-9 here, so the denominator is positive.
        factor_weights = usable_ic / float(np.sum(np.abs(usable_ic)))
        LOGGER.info("使用IC均值作为权重: %s", factor_weights.tolist(), extra=log_extra)

    weight_mask = np.abs(factor_weights) > 1e-6
    filtered_factors = [name for name, keep in zip(usable_factors, weight_mask) if keep]
    filtered_weights = [float(weight) for weight, keep in zip(factor_weights, weight_mask) if keep]
    if not filtered_factors:
        st.info("因子权重有效值均为零，无法生成股票评分。")
        return
    if len(filtered_factors) < len(usable_factors):
        dropped_names = [name for name, keep in zip(usable_factors, weight_mask) if not keep]
        LOGGER.info("已忽略权重为零的因子：%s", dropped_names, extra=log_extra)

    # Scoring walks the whole universe, so reuse the last result while its
    # inputs are unchanged (e.g. on the rerun triggered by the pool button).
    score_key = (
        tuple(filtered_factors),
        tuple(filtered_weights),
        trade_date_str,
        tuple(universe) if universe else None,
        tuple(sorted(fundamental_filters.items())),
    )
    cached = st.session_state.get("stock_scores_cache")
    if cached and cached.get("key") == score_key:
        scores = cached["scores"]
    else:
        with st.spinner("正在生成股票评分..."):
            scores = _calculate_stock_scores(
                universe,
                filtered_factors,
                end_date,
                filtered_weights,
                valuation_filters=fundamental_filters if fundamental_filters else None,
            )
        st.session_state.stock_scores_cache = {"key": score_key, "scores": scores}

    if not scores:
        st.info("无法根据当前因子权重生成有效的股票评分结果。")
        return

    st.markdown("##### 股票综合评分 (Top 20)")
    score_df = pd.DataFrame(scores).sort_values(
        "综合评分",
        ascending=False
    )
    top_df = score_df.head(20).reset_index(drop=True)
    display_scores = top_df.copy()
    display_scores["综合评分"] = display_scores["综合评分"].map(lambda v: f"{v:.4f}")
    st.dataframe(
        display_scores,
        hide_index=True,
        width="stretch"
    )

    if st.button("将Top 20股票加入股票池"):
        try:
            _add_to_stock_pool(top_df, end_date)
        except Exception as exc:  # UI boundary: surface the failure to the user
            LOGGER.exception("写入股票池失败", extra=log_extra)
            st.error(f"写入股票池失败: {exc}")
        else:
            st.success("已成功将选中股票加入股票池！")


def _coerce_finite(value: object) -> Optional[float]:
    """Return ``value`` as a finite float, or ``None`` when that is impossible.

    ``None``, values ``float()`` rejects, NaN and +/-inf all map to ``None``
    so callers can treat them uniformly as missing data.
    """
    if value is None:
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if np.isfinite(number) else None


def _passes_valuation_filters(
    payload,
    valuation_filters: Dict[str, Tuple[float, float]],
    field_map: Dict[str, str],
) -> bool:
    """Return True when every recognised valuation metric lies within its bounds.

    ``payload`` is the mapping returned by ``DataBroker.fetch_latest`` (read
    via ``.get``). Filter keys missing from ``field_map`` are ignored. A
    missing or non-finite metric fails the filter, since NaN would otherwise
    slip through both range comparisons. ``total_mv_billion`` bounds are in
    亿元; the raw field is divided by 10,000 before comparing (presumably it
    is stored in 万元 — confirm).
    """
    for key, bounds in valuation_filters.items():
        field_name = field_map.get(key)
        if not field_name:
            continue
        metric_value = _coerce_finite(payload.get(field_name))
        if metric_value is None:
            return False
        if key == "total_mv_billion":
            metric_value /= 10000.0
        lower, upper = bounds
        if metric_value < lower or metric_value > upper:
            return False
    return True


def _extract_factor_values(payload, factor_fields: List[str]) -> Optional[List[float]]:
    """Return the finite factor values for ``factor_fields`` in order, or None if any is unusable."""
    values: List[float] = []
    for field in factor_fields:
        number = _coerce_finite(payload.get(field))
        if number is None:
            return None
        values.append(number)
    return values


def _calculate_stock_scores(
    universe: Optional[List[str]],
    factors: List[str],
    eval_date: date,
    factor_weights: List[float],
    valuation_filters: Optional[Dict[str, Tuple[float, float]]] = None
) -> List[Dict[str, object]]:
    """Score stocks as a weighted sum of their latest factor values.

    Args:
        universe: Stock codes to score; falsy means every stock returned by
            ``DataBroker.get_all_stocks`` for ``eval_date``.
        factors: Factor names, read from the ``factors.<name>`` fields.
        eval_date: Date whose latest data is used.
        factor_weights: One weight per factor; re-normalised so absolute
            values sum to 1 (uniform if they are all zero).
        valuation_filters: Optional ``{key: (lower, upper)}`` ranges for
            ``pe_ttm``, ``pb``, ``ps_ttm`` and ``total_mv_billion``.

    Returns:
        One dict per scored stock with keys ``股票代码``, ``股票名称``,
        ``行业`` and ``综合评分``. Stocks with insufficient data, missing or
        non-finite factor/valuation values, failed filters or no stock info
        are skipped. Returns ``[]`` on invalid arguments or an empty universe.
    """
    log_extra = {"stage": "stock_evaluation"}

    if not factors:
        LOGGER.warning("因子列表为空，无法计算股票评分", extra=log_extra)
        return []
    if len(factors) != len(factor_weights):
        LOGGER.error(
            "因子数量与权重数量不一致 factors=%s weights=%s",
            len(factors),
            len(factor_weights),
            extra=log_extra,
        )
        return []

    broker = DataBroker()
    trade_date_str = eval_date.strftime("%Y%m%d")

    LOGGER.info(
        "开始股票评估评估日期=%s 因子数量=%d 权重=%s",
        eval_date.strftime("%Y-%m-%d"),
        len(factors),
        factor_weights,
        extra=log_extra
    )

    # Normalise weights; fall back to uniform when the absolute sum is not positive.
    weights = np.array(factor_weights, dtype=float)
    abs_sum = np.sum(np.abs(weights))
    if abs_sum > 0:
        weights = weights / abs_sum
    else:
        weights = np.ones_like(weights) / len(weights)

    stocks = universe or broker.get_all_stocks(trade_date_str)
    if not stocks:
        LOGGER.warning("股票列表为空，无法生成评分", extra=log_extra)
        return []

    LOGGER.info(
        "获取股票列表 universe_size=%d total_stocks=%d",
        len(universe) if universe else 0,
        len(stocks),
        extra=log_extra
    )

    factor_fields = [f"factors.{name}" for name in factors]
    valuation_field_map = {
        "pe_ttm": "daily_basic.pe_ttm",
        "pb": "daily_basic.pb",
        "ps_ttm": "daily_basic.ps_ttm",
        "total_mv_billion": "daily_basic.total_mv",
    }
    extra_fields: List[str] = []
    if valuation_filters:
        LOGGER.info("估值筛选条件: %s", valuation_filters, extra=log_extra)
        for key in valuation_filters:
            field_name = valuation_field_map.get(key)
            if field_name and field_name not in extra_fields:
                extra_fields.append(field_name)
    fetch_fields = factor_fields + [field for field in extra_fields if field not in factor_fields]

    results: List[Dict[str, object]] = []
    evaluated_count = 0
    skipped_count = 0

    for ts_code in stocks:
        if not check_data_sufficiency(ts_code, trade_date_str):
            skipped_count += 1
            continue

        latest_payload = broker.fetch_latest(
            ts_code,
            trade_date_str,
            fetch_fields,
            auto_refresh=False,
        )
        if not latest_payload:
            skipped_count += 1
            continue

        if valuation_filters and not _passes_valuation_filters(
            latest_payload, valuation_filters, valuation_field_map
        ):
            skipped_count += 1
            continue

        factor_values = _extract_factor_values(latest_payload, factor_fields)
        if factor_values is None:
            skipped_count += 1
            continue

        info = broker.get_stock_info(ts_code, trade_date_str)
        if not info:
            skipped_count += 1
            continue

        evaluated_count += 1
        results.append({
            "股票代码": ts_code,
            "股票名称": info.get("name", ""),
            "行业": info.get("industry", ""),
            "综合评分": float(np.dot(factor_values, weights)),
        })

    LOGGER.info(
        "股票评估完成 总股票数=%d 已评估=%d 跳过=%d 结果数=%d",
        len(stocks),
        evaluated_count,
        skipped_count,
        len(results),
        extra=log_extra
    )

    return results


def _add_to_stock_pool(
    score_df: pd.DataFrame,
    eval_date: date
) -> None:
    """Write scored stocks into ``investment_pool`` for ``eval_date``.

    Existing ``investment_pool`` rows for the trade date (``YYYYMMDD``) are
    deleted, then one ``candidate`` row per ``score_df`` row is inserted with a
    1-based rank in frame order. ``score_df`` must provide ``股票代码`` and
    ``综合评分``. When it also has ``股票名称`` and ``行业`` columns, those are
    reused; otherwise name/industry are looked up via ``DataBroker``. An
    empty ``score_df`` is a no-op, so the day's existing rows are left intact.
    """
    trade_date = eval_date.strftime("%Y%m%d")
    ranked_df = score_df.reset_index(drop=True)
    if ranked_df.empty:
        LOGGER.info("评分结果为空，跳过写入股票池 trade_date=%s", trade_date, extra=LOG_EXTRA)
        return

    has_info_columns = {"股票名称", "行业"}.issubset(ranked_df.columns)
    broker = None if has_info_columns else DataBroker()
    # Identical for every row, so serialise it once.
    tags = json.dumps(["stock_evaluation", "top20"], ensure_ascii=False)
    payload: List[tuple] = []

    for rank, row in ranked_df.iterrows():
        ts_code = row["股票代码"]
        score = float(row["综合评分"])
        metadata = json.dumps(
            {
                "source": "stock_evaluation",
                "rank": rank + 1,
                "score": score,
            },
            ensure_ascii=False,
        )
        if has_info_columns:
            stock_name = row["股票名称"] if isinstance(row["股票名称"], str) else ""
            stock_industry = row["行业"] if isinstance(row["行业"], str) else ""
        else:
            stock_info = broker.get_stock_info(ts_code, trade_date)
            stock_name = stock_info.get("name", "") if stock_info else ""
            stock_industry = stock_info.get("industry", "") if stock_info else ""

        payload.append(
            (
                trade_date,
                ts_code,
                score,
                "candidate",
                "factor_evaluation_top20",
                tags,
                metadata,
                stock_name,
                stock_industry,
            )
        )

    with db_session() as conn:
        _ensure_investment_pool_schema(conn)
        # NOTE(review): this clears every row for the trade date, not only rows
        # written by this view — confirm no other writer shares the date.
        conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
        conn.executemany(
            """
            INSERT INTO investment_pool (
                trade_date,
                ts_code,
                score,
                status,
                rationale,
                tags,
                metadata,
                name,
                industry
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            payload,
        )
